{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Load libraries"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Enable IPython's autoreload extension in mode 2, so imported modules are\n",
    "# reloaded before each cell executes and edits to project code take effect\n",
    "# without restarting the kernel.\n",
    "%load_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys\n",
    "\n",
    "# Put the parent directory on the import path for the `src.*` imports below\n",
    "# (assumes the notebook is run from its own directory). The membership check\n",
    "# keeps re-runs of this cell from appending duplicate entries.\n",
    "if \"../\" not in sys.path:\n",
    "    sys.path.append(\"../\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Using TensorFlow version: 2.0.0-beta1\n",
      "GPU Available: False\n"
     ]
    }
   ],
   "source": [
    "import os\n",
    "from pathlib import Path\n",
    "import json\n",
    "\n",
    "from src.lstm import AttentionLSTM\n",
    "from src.metrics import exact_match_metric\n",
    "from src.callbacks import NValidationSetsCallback\n",
    "from src.generators import DataGeneratorAttention\n",
    "from src.utils import get_sequence_data\n",
    "\n",
    "import tensorflow as tf\n",
    "from tensorflow.keras.optimizers import Adam\n",
    "\n",
    "# Report the TensorFlow version and GPU visibility for this run.\n",
    "print(\"Using TensorFlow version:\", tf.__version__)\n",
    "print(\"GPU Available:\", tf.test.is_gpu_available())"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Paths"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Directory containing the JSON settings files.\n",
    "SETTINGS = Path(\"..\", \"settings\")\n",
    "# Directory for processed data (not referenced again in this notebook).\n",
    "DATA = Path(\"..\", \"data\", \"processed\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Load settings"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'math_module': 'arithmetic__add_sub',\n",
       " 'train_level': '*',\n",
       " 'batch_size': 1024,\n",
       " 'thinking_steps': 16,\n",
       " 'epochs': 1,\n",
       " 'num_encoder_units': 512,\n",
       " 'num_decoder_units': 2048,\n",
       " 'embedding_dim': 2048,\n",
       " 'save_path': '../data/',\n",
       " 'data_path': '../data/'}"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Read the local settings file; the trailing expression displays its contents.\n",
    "settings_path = SETTINGS / \"settings_local.json\"\n",
    "\n",
    "with settings_path.open(\"r\") as settings_file:\n",
    "    settings_dict = json.load(settings_file)\n",
    "\n",
    "settings_dict"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Load data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "# `input_texts` / `target_texts` are indexed by split name further down\n",
    "# (\"train\", \"valid\", \"interpolate\", \"extrapolate\"); `data_gen_pars` is unpacked\n",
    "# as keyword arguments into every DataGeneratorAttention and also supplies the\n",
    "# `num_*_tokens` and `max_*_seq_length` values passed to AttentionLSTM.\n",
    "data_gen_pars, input_texts, target_texts = get_sequence_data(settings_dict)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Number of training samples: 1599998\n"
     ]
    }
   ],
   "source": [
    "print(f\"Number of training samples: {len(input_texts['train'])}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Number of interpolation samples: 10000\n"
     ]
    }
   ],
   "source": [
    "# This counts the `interpolate` split; the validation generator further down\n",
    "# is built from the separate `valid` split, so the label names the split\n",
    "# that is actually measured here.\n",
    "print('Number of interpolation samples:', len(input_texts['interpolate']))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Data generators"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "def make_generator(split):\n",
    "    \"\"\"Return a DataGeneratorAttention over the `split` entry of the loaded texts.\"\"\"\n",
    "    return DataGeneratorAttention(\n",
    "        input_texts=input_texts[split],\n",
    "        target_texts=target_texts[split],\n",
    "        **data_gen_pars\n",
    "    )\n",
    "\n",
    "\n",
    "# One generator per split; all of them share the same generator parameters.\n",
    "training_generator = make_generator(\"train\")\n",
    "validation_generator = make_generator(\"valid\")\n",
    "interpolate_generator = make_generator(\"interpolate\")\n",
    "extrapolate_generator = make_generator(\"extrapolate\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Load model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "W0921 18:33:44.146364 4557796800 deprecation.py:323] From /Users/lewtun/git/deep-math/env/lib/python3.7/site-packages/tensorflow/python/keras/backend.py:3868: add_dispatch_support.<locals>.wrapper (from tensorflow.python.ops.array_ops) is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "Use tf.where in 2.0, which has the same broadcast rule as np.where\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Model: \"model\"\n",
      "__________________________________________________________________________________________________\n",
      "Layer (type)                    Output Shape         Param #     Connected to                     \n",
      "==================================================================================================\n",
      "input_1 (InputLayer)            [(None, None)]       0                                            \n",
      "__________________________________________________________________________________________________\n",
      "input_2 (InputLayer)            [(None, None)]       0                                            \n",
      "__________________________________________________________________________________________________\n",
      "embedding (Embedding)           (None, None, 2048)   67584       input_1[0][0]                    \n",
      "__________________________________________________________________________________________________\n",
      "embedding_1 (Embedding)         (None, None, 2048)   28672       input_2[0][0]                    \n",
      "__________________________________________________________________________________________________\n",
      "bidirectional (Bidirectional)   (None, None, 1024)   10489856    embedding[0][0]                  \n",
      "__________________________________________________________________________________________________\n",
      "lstm_1 (LSTM)                   (None, None, 2048)   33562624    embedding_1[0][0]                \n",
      "__________________________________________________________________________________________________\n",
      "dense (Dense)                   (None, None, 2048)   2099200     bidirectional[0][0]              \n",
      "__________________________________________________________________________________________________\n",
      "dot (Dot)                       (None, None, None)   0           lstm_1[0][0]                     \n",
      "                                                                 dense[0][0]                      \n",
      "__________________________________________________________________________________________________\n",
      "attention (Activation)          (None, None, None)   0           dot[0][0]                        \n",
      "__________________________________________________________________________________________________\n",
      "dot_1 (Dot)                     (None, None, 2048)   0           attention[0][0]                  \n",
      "                                                                 dense[0][0]                      \n",
      "__________________________________________________________________________________________________\n",
      "concatenate (Concatenate)       (None, None, 4096)   0           dot_1[0][0]                      \n",
      "                                                                 lstm_1[0][0]                     \n",
      "__________________________________________________________________________________________________\n",
      "time_distributed (TimeDistribut (None, None, 2048)   8390656     concatenate[0][0]                \n",
      "__________________________________________________________________________________________________\n",
      "time_distributed_1 (TimeDistrib (None, None, 13)     26637       time_distributed[0][0]           \n",
      "==================================================================================================\n",
      "Total params: 54,665,229\n",
      "Trainable params: 54,665,229\n",
      "Non-trainable params: 0\n",
      "__________________________________________________________________________________________________\n"
     ]
    }
   ],
   "source": [
    "# AttentionLSTM takes seven positional arguments here: four entries of\n",
    "# `data_gen_pars` followed by three entries of `settings_dict`, in this order.\n",
    "data_keys = (\n",
    "    \"num_encoder_tokens\",\n",
    "    \"num_decoder_tokens\",\n",
    "    \"max_encoder_seq_length\",\n",
    "    \"max_decoder_seq_length\",\n",
    ")\n",
    "settings_keys = (\"num_encoder_units\", \"num_decoder_units\", \"embedding_dim\")\n",
    "\n",
    "lstm = AttentionLSTM(\n",
    "    *(data_gen_pars[key] for key in data_keys),\n",
    "    *(settings_dict[key] for key in settings_keys),\n",
    ")\n",
    "model = lstm.get_model()\n",
    "model.summary()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [],
   "source": [
    "# `learning_rate` is the documented name of this argument in tf.keras;\n",
    "# `lr` is only accepted as a backwards-compatibility alias.\n",
    "adam = Adam(\n",
    "    learning_rate=6e-4,\n",
    "    beta_1=0.9,\n",
    "    beta_2=0.995,\n",
    "    epsilon=1e-9,\n",
    "    decay=0.0,\n",
    "    amsgrad=False,\n",
    "    clipnorm=0.1,\n",
    ")\n",
    "# Track the project's exact-match metric alongside the cross-entropy loss.\n",
    "model.compile(\n",
    "    optimizer=adam, loss=\"categorical_crossentropy\", metrics=[exact_match_metric]\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Configure callbacks"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Evaluation generators keyed by name; this mapping is what\n",
    "# NValidationSetsCallback is constructed from.\n",
    "valid_dict = dict(\n",
    "    validation=validation_generator,\n",
    "    interpolation=interpolate_generator,\n",
    "    extrapolation=extrapolate_generator,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Callback built from the named evaluation generators and handed to training\n",
    "# below. NOTE(review): presumably it evaluates each set during training and\n",
    "# records the results -- confirm in src/callbacks.py.\n",
    "history = NValidationSetsCallback(valid_dict)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "metadata": {},
   "outputs": [],
   "source": [
    "# directory where the checkpoints will be saved; os.path.join keeps the path\n",
    "# correct whether or not `save_path` ends with a path separator\n",
    "checkpoint_dir = os.path.join(settings_dict[\"save_path\"], \"training_checkpoints\")\n",
    "# name of the checkpoint files; `{epoch}` is a placeholder in the file name\n",
    "checkpoint_prefix = os.path.join(checkpoint_dir, \"ckpt_{epoch}\")\n",
    "\n",
    "# only the model weights are written, not the full model\n",
    "checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(\n",
    "    filepath=checkpoint_prefix, save_weights_only=True\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Train model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Train for the configured number of epochs. No `validation_data` is given\n",
    "# here; the evaluation generators were passed to the `history` callback instead.\n",
    "train_hist = model.fit_generator(\n",
    "    training_generator,\n",
    "    epochs=settings_dict[\"epochs\"],\n",
    "    callbacks=[history, checkpoint_callback],\n",
    "    verbose=1,\n",
    ")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
